from pathlib import Path

import torch
from matplotlib import pyplot as plt

from aslp_ai.utils.class_utils import ACTIVATION_CLASSES

def show_all_activate_func():
    """Plot and save a figure for every registered activation function.

    Iterates ``ACTIVATION_CLASSES`` (registry key -> activation class),
    printing each key alongside the class name, then delegating to
    ``plot_activation`` to render and save one PNG per activation.
    """
    for key, activation_cls in ACTIVATION_CLASSES.items():
        print(key, activation_cls.__name__)
        plot_activation(activation_cls)


def plot_activation(activation_func, x_range=(-25, 25), num_points=1000, save_dir="./fig"):
    """Plot an activation function over a sample range and save it as a PNG.

    Args:
        activation_func: Activation *class* (not instance); it is instantiated
            with no arguments and applied element-wise to a 1-D tensor.
        x_range: ``(min, max)`` interval to sample.
        num_points: Number of evenly spaced sample points.
        save_dir: Directory the PNG is written to; created if missing.
    """
    x = torch.linspace(x_range[0], x_range[1], num_points)
    # detach() guards against activations with learnable parameters
    # (e.g. PReLU), whose output carries grad and would make .numpy() raise.
    y = activation_func()(x).detach()

    plt.figure(figsize=(6, 4))
    plt.plot(x.numpy(), y.numpy(), label=activation_func.__name__)
    plt.title(f"{activation_func.__name__} Activation Function")
    plt.xlabel('x')
    plt.ylabel('f(x)')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    # Save the figure; create the target directory first so savefig does not
    # raise FileNotFoundError on a fresh checkout.
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_dir / f"{activation_func.__name__}.png")
    plt.close()


if __name__ == '__main__':
    # Render and save a figure for every registered activation function.
    show_all_activate_func()